import numpy as np
from environment import Environment
from data import Data

class Likelihood:
    """Base likelihood model over a single noise parameter ``theta``.

    Subclasses replace ``lnlike``; ``lnprior`` and ``lnprob`` pair it with a
    flat prior that only admits non-negative parameter values.
    """

    def lnlike(self, theta, trials):
        """Placeholder log-likelihood that ignores its inputs and returns 1."""
        return 1

    def lnprior(self, theta):
        """Flat prior on a non-negative observation noise.

        Returns 0.0 when ``theta >= 0`` and ``-np.inf`` otherwise (a NaN
        ``theta`` fails the comparison and is therefore rejected too).
        """
        return 0.0 if theta >= 0 else -np.inf

    def lnprob(self, theta, trials):
        """Log-posterior ``lnprior(theta) + lnlike(theta, trials)``.

        The likelihood is not evaluated at all when the prior rejects ``theta``.
        """
        log_prior = self.lnprior(theta)
        if log_prior == -np.inf:
            return -np.inf
        return log_prior + self.lnlike(theta, trials)

class AvgLike(Likelihood):
    """Likelihood that compares pooled predicted and observed accuracy."""

    def lnlike(self, sdz, train_trials):
        """Return ``-0.5 * (predicted_accuracy - observed_accuracy) ** 2``.

        Predicted accuracy comes from ``Environment.get_pred_acc(train_trials, sdz)``;
        observed accuracy from ``Environment.get_real_acc(train_trials)``.
        """
        residual = Environment.get_pred_acc(train_trials, sdz) - Environment.get_real_acc(train_trials)
        return -.5 * residual ** 2

class BernLike(Likelihood):
    """Bernoulli log-likelihood of pooled observed accuracy under the predicted accuracy."""

    def lnlike(self, sdz, train_trials):
        """Return the Bernoulli log-likelihood of all ``train_trials`` pooled together.

        ``sdz`` is forwarded to ``Environment.get_pred_acc`` as both of its noise
        arguments.

        The predicted accuracy is clipped into ``[1e-10, 1 - 1e-10]`` before the
        logs, matching the clamping used by the weighted variants. Without the
        clip, a prediction of exactly 0 or 1 gives ``np.log(0) = -inf``. When the
        observed accuracy is on the same boundary, the product ``0 * -inf``
        evaluates to ``nan``.
        """
        eps = 10 ** -10
        pred_acc = np.clip(Environment.get_pred_acc(train_trials, sdz, sdz), eps, 1 - eps)
        real_acc = Environment.get_real_acc(train_trials)
        return real_acc * np.log(pred_acc) + (1 - real_acc) * np.log(1 - pred_acc)

class WeightedLike(Likelihood):
    """Squared-error accuracy likelihood, computed per session/direction group."""

    def lnlike(self, sdz, train_trials):
        """Return ``-0.5 * sum(n_g * (pred_g - real_g) ** 2)`` over trial groups.

        Groups come from ``Data.split_session_direction``. Each non-empty group
        contributes its squared accuracy error, weighted by its trial count
        ``n_g``. Empty groups are skipped.
        """
        groups = Data.split_session_direction(train_trials)
        weighted_sq_error = sum(
            len(group) * (Environment.get_pred_acc(group, sdz) - Environment.get_real_acc(group)) ** 2
            for group in groups.values()
            if len(group) != 0
        )
        return -.5 * weighted_sq_error


class WeightedBernLike(Likelihood):
    """Bernoulli accuracy log-likelihood summed over session/direction groups."""

    def lnlike(self, sdz, train_trials):
        """Return the trial-count-weighted Bernoulli log-likelihood across groups.

        ``sdz`` is passed to ``Environment.get_pred_acc`` as both noise arguments.
        A predicted accuracy of exactly 1 or 0 is nudged by 1e-10 so that
        ``np.log`` stays finite. Empty groups are skipped.
        """
        tiny = 10 ** -10
        total = 0
        for group in Data.split_session_direction(train_trials).values():
            if len(group) == 0:
                continue
            p_hat = Environment.get_pred_acc(group, sdz, sdz)
            observed = Environment.get_real_acc(group)
            if p_hat == 1:
                p_hat = 1 - tiny
            elif p_hat == 0:
                p_hat = tiny
            total += len(group) * (observed * np.log(p_hat) + (1 - observed) * np.log(1 - p_hat))
        return total

class SigzLike(Likelihood):
    """Pooled Bernoulli accuracy log-likelihood parameterised by ``sigz``."""

    def lnlike(self, sigz, trials):
        """Return the Bernoulli log-likelihood of all ``trials`` pooled together.

        ``sigz`` is forwarded to ``Environment.get_pred_acc`` as both noise
        arguments.

        The predicted accuracy is clipped into ``[1e-10, 1 - 1e-10]``. The
        original code only guarded the upper boundary, so a prediction of exactly
        0 still produced ``np.log(0)``, and a ``nan`` when the observed accuracy
        was also 0.
        """
        eps = 10 ** -10
        pred_acc = np.clip(Environment.get_pred_acc(trials, sigz, sigz), eps, 1 - eps)
        real_acc = Environment.get_real_acc(trials)
        return real_acc * np.log(pred_acc) + (1 - real_acc) * np.log(1 - pred_acc)


class SdzLike:
    """Weighted Bernoulli accuracy likelihood with separate ``sigz`` / ``sigz_sub`` noises."""

    def lnlike(self, sigz, sigz_sub, trials):
        """Sum of per-group Bernoulli log-likelihoods, weighted by group size.

        Groups come from ``Data.split_session_direction``; empty groups are
        skipped. A predicted accuracy of exactly 1 or 0 is nudged by 1e-10 to
        keep ``np.log`` finite.
        """
        tiny = 10 ** -10
        log_like = 0
        for group in Data.split_session_direction(trials).values():
            if len(group) == 0:
                continue
            p_hat = Environment.get_pred_acc(group, sigz, sigz_sub)
            observed = Environment.get_real_acc(group)
            if p_hat == 1:
                p_hat = 1 - tiny
            elif p_hat == 0:
                p_hat = tiny
            log_like += len(group) * (observed * np.log(p_hat) + (1 - observed) * np.log(1 - p_hat))
        return log_like

    def lnprior(self, sig):
        """Flat prior: 0.0 for ``sig >= 0``, ``-np.inf`` otherwise (NaN included)."""
        return 0.0 if sig >= 0 else -np.inf

    def lnprob(self, sigz, sigz_sub, trials):
        """Log-posterior; only ``sigz_sub`` is checked against the prior.

        ``sigz`` is passed through to the likelihood without a prior check.
        """
        log_prior = self.lnprior(sigz_sub)
        if log_prior == -np.inf:
            return -np.inf
        return log_prior + self.lnlike(sigz, sigz_sub, trials)

class BiasLike(Likelihood):
    """Weighted Bernoulli accuracy likelihood that also takes a ``bias`` term."""

    def lnlike(self, sigz, sigz_sub, bias, trials):
        """Group-size-weighted Bernoulli log-likelihood over session/direction groups.

        ``bias`` is forwarded to ``Environment.get_pred_acc`` as its fourth
        argument. Empty groups are skipped. A predicted accuracy of exactly 1 or
        0 is nudged by 1e-10 to keep ``np.log`` finite.
        """
        floor = 10 ** -10
        result = 0
        for group in Data.split_session_direction(trials).values():
            if len(group) == 0:
                continue
            predicted = Environment.get_pred_acc(group, sigz, sigz_sub, bias)
            actual = Environment.get_real_acc(group)
            if predicted == 1:
                predicted = 1 - floor
            elif predicted == 0:
                predicted = floor
            result += len(group) * (actual * np.log(predicted) + (1 - actual) * np.log(1 - predicted))
        return result

    def lnprior(self, sig):
        """Flat, unbounded prior: every ``sig`` is accepted and scores 0."""
        return 0

    def lnprob(self, sigz, sigz_sub, bias, trials):
        """Log-posterior: prior evaluated on ``sigz_sub`` plus the likelihood."""
        log_prior = self.lnprior(sigz_sub)
        if log_prior == -np.inf:
            return -np.inf
        return log_prior + self.lnlike(sigz, sigz_sub, bias, trials)

class ConfAvgLike(Likelihood):
    """Squared-error likelihood on the pooled confidence rate.

    Stores ``sdz`` and ``beta``, which are used when predicting the confidence
    rate. The sampled parameter is the confidence cutoff.
    """

    def __init__(self, sdz, beta):
        self.sdz = sdz
        self.beta = beta

    def lnlike(self, conf_cutoff, train_trials):
        """Return ``-0.5 * (predicted_conf_rate - observed_conf_rate) ** 2``.

        The observed rate is computed by ``Environment.calc_conf`` from each
        trial's ``conf`` attribute. This is the only per-trial field the
        computation needs; the previously collected stimulus/prior/value arrays
        were never used.
        """
        confs = np.array([trial.conf for trial in train_trials], dtype=float)
        # NOTE(review): other call sites pass (trials, conf_cutoff, sigz, sdz, beta[, bias])
        # to get_pred_conf; here sdz/beta land in the 3rd/4th slots — confirm this is intended.
        pred_conf_rate = Environment.get_pred_conf(train_trials, conf_cutoff, self.sdz, self.beta)
        real_conf_rate = Environment.calc_conf(confs)
        return -.5 * (pred_conf_rate - real_conf_rate) ** 2

    def lnprior(self, theta):
        """Flat prior on the confidence cutoff.

        Returns 0.0 when ``0.1 <= theta <= 1`` and ``-np.inf`` otherwise.
        """
        return 0.0 if .1 <= theta <= 1 else -np.inf


class WeightedConf(ConfAvgLike):
    """Bernoulli likelihood of confidence rates, weighted per session/direction group.

    ``conf_type`` selects how the predicted confidence rate is computed:
    0 --> planning as inference (``Environment.get_pred_conf``)
    1 --> max probability of observation
    2 --> max posterior probability
    3 --> max posterior expected value
    Types 1-3 go through ``Environment.get_pred_conf_other``.
    """

    def __init__(self, sigz, sdz, beta, conf_type, bias=0):
        super().__init__(sdz, beta)
        self.conf_type = conf_type
        self.sigz = sigz
        self.bias = bias

    def lnlike(self, conf_cutoff, train_trials):
        """Sum of per-group Bernoulli log-likelihoods of the observed confidence rate.

        Groups come from ``Data.split_session_direction``. Each group is weighted
        by its trial count.

        Empty groups are skipped, consistent with the other weighted
        likelihoods. An empty group would contribute ``0 * nan`` if its rates are
        undefined, which poisons the whole sum. The predicted rate is clipped into
        ``[1e-10, 1 - 1e-10]`` so ``np.log`` stays finite.
        """
        eps = 10 ** -10
        # presumably one entry per session/direction combination (original note: 14 keys)
        trials_split = Data.split_session_direction(train_trials)

        ln_like = 0
        for session in trials_split.values():
            if len(session) == 0:
                continue
            pred_conf_rate, real_conf_rate = self.conf_condition(
                conf_cutoff, session, self.sigz, self.sdz, self.beta, self.bias)
            pred_conf_rate = np.clip(pred_conf_rate, eps, 1 - eps)
            ln_like += len(session) * (real_conf_rate * np.log(pred_conf_rate)
                                       + (1 - real_conf_rate) * np.log(1 - pred_conf_rate))
        return ln_like

    def conf_condition(self, conf_cutoff, trials_cond, sigz, sdz, beta, bias):
        """Return ``(predicted_conf_rate, observed_conf_rate)`` for one group of trials."""
        if self.conf_type == 0:
            pred_conf_rate = Environment.get_pred_conf(trials_cond, conf_cutoff, sigz, sdz, beta, bias)
        else:
            pred_conf_rate = Environment.get_pred_conf_other(trials_cond, conf_cutoff, sigz, sdz, self.conf_type, bias)
        real_conf_rate = Environment.get_real_conf(trials_cond)
        return pred_conf_rate, real_conf_rate

class WeightedConfBias(ConfAvgLike):
    """Group-weighted Bernoulli likelihood of confidence rates, including a bias term.

    ``conf_type`` selects how the predicted confidence rate is computed:
    0 --> planning as inference (``Environment.get_pred_conf``)
    1 --> max probability of observation
    2 --> max posterior probability
    3 --> max posterior expected value
    Types 1-3 go through ``Environment.get_pred_conf_other``.
    """

    def __init__(self, sigz, sdz, bias, beta, conf_type):
        super().__init__(sdz, beta)
        self.conf_type = conf_type
        self.sigz = sigz
        self.bias = bias

    def lnlike(self, conf_cutoff, train_trials):
        """Sum of per-group Bernoulli log-likelihoods of the observed confidence rate.

        Each group from ``Data.split_session_direction`` is weighted by its trial
        count. Empty groups are skipped. The predicted rate is clipped into
        ``[1e-10, 1 - 1e-10]`` so ``np.log`` stays finite.

        Previously this read ``self.bais``, a typo that raised AttributeError on
        every call.
        """
        eps = 10 ** -10
        trials_split = Data.split_session_direction(train_trials)

        ln_like = 0
        for session in trials_split.values():
            if len(session) == 0:
                continue
            pred_conf_rate, real_conf_rate = self.conf_condition(
                conf_cutoff, session, self.sigz, self.sdz, self.bias, self.beta)
            pred_conf_rate = np.clip(pred_conf_rate, eps, 1 - eps)
            ln_like += len(session) * (real_conf_rate * np.log(pred_conf_rate)
                                       + (1 - real_conf_rate) * np.log(1 - pred_conf_rate))
        return ln_like

    def conf_condition(self, conf_cutoff, trials_cond, sigz, sdz, bias, beta):
        """Return ``(predicted_conf_rate, observed_conf_rate)`` for one group of trials.

        ``bias`` is forwarded to the prediction, in the same argument position
        that ``WeightedConf`` uses. The previous version accepted ``bias`` but
        silently dropped it.
        """
        if self.conf_type == 0:
            pred_conf_rate = Environment.get_pred_conf(trials_cond, conf_cutoff, sigz, sdz, beta, bias)
        else:
            pred_conf_rate = Environment.get_pred_conf_other(trials_cond, conf_cutoff, sigz, sdz, self.conf_type, bias)
        real_conf_rate = Environment.get_real_conf(trials_cond)
        return pred_conf_rate, real_conf_rate


class ConfBernLike(ConfAvgLike):
    """Bernoulli log-likelihood of the pooled confidence rate."""

    def __init__(self, sdz, beta):
        super().__init__(sdz, beta)

    def lnlike(self, conf_cutoff, train_trials):
        """Return the Bernoulli log-likelihood of the observed confidence rate.

        The predicted rate is clipped into ``[1e-10, 1 - 1e-10]``. A prediction
        of exactly 0 or 1 would otherwise produce ``np.log(0) = -inf``, and
        ``nan`` when the observed rate sits on the same boundary.
        """
        eps = 10 ** -10
        pred_conf_rate = np.clip(
            Environment.get_pred_conf(train_trials, conf_cutoff, self.sdz, self.beta), eps, 1 - eps)
        real_conf_rate = Environment.get_real_conf(train_trials)
        return real_conf_rate * np.log(pred_conf_rate) + (1 - real_conf_rate) * np.log(1 - pred_conf_rate)


class BetaLike:
    """Group-weighted Bernoulli likelihood of confidence rates, sampled over ``beta``."""

    def lnlike(self, beta, sigz, sigz_sub, conf_cutoff, train_trials):
        """Sum of per-group Bernoulli log-likelihoods of the observed confidence rate.

        Each group from ``Data.split_session_direction`` is weighted by its trial
        count. Empty groups are skipped. The predicted rate is clipped into
        ``[1e-10, 1 - 1e-10]`` so ``np.log`` stays finite.
        """
        eps = 10 ** -10
        trials_split = Data.split_session_direction(train_trials)
        ln_like = 0
        for session in trials_split.values():
            if len(session) == 0:
                continue
            pred_conf = np.clip(
                Environment.get_pred_conf(session, conf_cutoff, sigz, sigz_sub, beta), eps, 1 - eps)
            real_conf = Environment.get_real_conf(session)
            ln_like += len(session) * (real_conf * np.log(pred_conf) + (1 - real_conf) * np.log(1 - pred_conf))
        return ln_like

    # makes sure the given beta is within its prior limits (non-negative)
    # returns 0 if within limits, returns negative infinity if not
    def lnprior(self, theta):
        beta = theta
        if beta >= 0:
            return 0.0
        else:
            return -np.inf

    # calculates the log-probability of a given beta given the model and the data
    # by combining the log-prior and log-likelihood
    def lnprob(self, beta, sigz, sigz_sub, conf_cutoff, trials):
        lp = self.lnprior(beta)
        if lp == -np.inf:
            return -np.inf
        # previously missing: the method fell through and returned None
        return lp + self.lnlike(beta, sigz, sigz_sub, conf_cutoff, trials)

class BetaLike2:
    """Group-weighted Bernoulli accuracy likelihood, sampled over ``beta``."""

    def lnlike(self, beta, sigz, sigz_sub, train_trials):
        """Sum of per-group Bernoulli log-likelihoods of the observed accuracy.

        ``beta`` is forwarded to ``Environment.get_pred_acc`` as ``d_beta``.
        Empty groups are skipped. The predicted accuracy is clipped into
        ``[1e-10, 1 - 1e-10]`` so ``np.log`` stays finite.
        """
        eps = 10 ** -10
        trials_split = Data.split_session_direction(train_trials)
        ln_like = 0
        for session in trials_split.values():
            if len(session) == 0:
                continue
            pred_acc = np.clip(
                Environment.get_pred_acc(session, sigz, sigz_sub, d_beta=beta), eps, 1 - eps)
            real_acc = Environment.get_real_acc(session)
            ln_like += len(session) * (real_acc * np.log(pred_acc) + (1 - real_acc) * np.log(1 - pred_acc))
        return ln_like

    # makes sure the given beta is within its prior limits (non-negative)
    # returns 0 if within limits, returns negative infinity if not
    def lnprior(self, theta):
        beta = theta
        if beta >= 0:
            return 0.0
        else:
            return -np.inf

    # calculates the log-probability of a given beta given the model and the data
    # by combining the log-prior and log-likelihood.
    # conf_cutoff is accepted so the signature matches BetaLike.lnprob, but the
    # accuracy likelihood does not use it; passing it on to lnlike (which takes
    # four arguments) previously raised TypeError on every call.
    def lnprob(self, beta, sigz, sigz_sub, conf_cutoff, trials):
        lp = self.lnprior(beta)
        if lp == -np.inf:
            return -np.inf
        return lp + self.lnlike(beta, sigz, sigz_sub, trials)

